﻿#pragma kernel ResetNormals
#pragma kernel UpdateNormals
#pragma kernel UpdateEdgeNormals
#pragma kernel OrientationFromNormals

#include "InterlockedUtils.cginc"
#include "MathUtils.cginc"

StructuredBuffer<int> deformableTriangles;
StructuredBuffer<float2> deformableTriangleUVs;

StructuredBuffer<int> deformableEdges;

StructuredBuffer<float4> renderablePositions;
StructuredBuffer<float4> velocities;
StructuredBuffer<float4> wind;
StructuredBuffer<int> phases;

RWStructuredBuffer<float4> normals;

RWStructuredBuffer<uint4> normalsInt;
RWStructuredBuffer<uint4> tangentsInt;
RWStructuredBuffer<quaternion> renderableOrientations;

// Variables set from the CPU
uint normalsCount;
uint triangleCount;
uint edgeCount;

// Adds the x, y and z components of `delta` to components 0, 1 and 2 of
// normalsInt[index], one InterlockedAddFloat call per component.
// delta.w is not accumulated.
void AccumulateNormal(in int index, in float4 delta)
{
    const float3 normalDelta = delta.xyz;

    InterlockedAddFloat(normalsInt, index, 0, normalDelta.x);
    InterlockedAddFloat(normalsInt, index, 1, normalDelta.y);
    InterlockedAddFloat(normalsInt, index, 2, normalDelta.z);
}

// Adds all four components of `delta` to the matching components of
// tangentsInt[index], one InterlockedAddFloat call per component.
void AccumulateTangent(in int index, in float4 delta)
{
    const float4 tangentDelta = delta;

    InterlockedAddFloat(tangentsInt, index, 0, tangentDelta.x);
    InterlockedAddFloat(tangentsInt, index, 1, tangentDelta.y);
    InterlockedAddFloat(tangentsInt, index, 2, tangentDelta.z);
    InterlockedAddFloat(tangentsInt, index, 3, tangentDelta.w);
}

// Kernel: one thread per particle (threads at or past normalsCount exit early).
// Zeroes normals[i] and tangentsInt[i] for every particle whose phase does not
// have the PHASE_FLUID bit set and whose normal has a non-negative w component.
[numthreads(128, 1, 1)]
void ResetNormals (uint3 id : SV_DispatchThreadID)
{
    const uint particle = id.x;
    if (particle >= normalsCount) return;

    const bool isFluid     = (phases[particle] & (int)PHASE_FLUID) != 0;
    const bool nonNegative = normals[particle].w >= 0;

    // Fluid and softbody data is left untouched; only the remaining
    // particles get their normal and tangent accumulators cleared.
    if (!isFluid && nonNegative)
    {
        normals[particle]     = FLOAT4_ZERO;
        tangentsInt[particle] = asuint(FLOAT4_ZERO);
    }
}

// Kernel: one thread per triangle (threads at or past triangleCount exit early).
// Computes the triangle's unnormalized face normal (cross product of two of its
// edges) and a UV-derived tangent, then accumulates both into each of the
// triangle's three particles.
[numthreads(128, 1, 1)]
void UpdateNormals (uint3 id : SV_DispatchThreadID)
{
    const uint tri = id.x;
    if (tri >= triangleCount) return;

    // Triangle data is stored as three consecutive entries per triangle.
    const uint first = tri * 3;

    const int a = deformableTriangles[first];
    const int b = deformableTriangles[first + 1];
    const int c = deformableTriangles[first + 2];

    const float2 uvA = deformableTriangleUVs[first];
    const float2 uvB = deformableTriangleUVs[first + 1];
    const float2 uvC = deformableTriangleUVs[first + 2];

    const float3 posA = renderablePositions[a].xyz;

    // Edges leaving the first corner, in position space and in UV space.
    const float3 edgeAB = renderablePositions[b].xyz - posA;
    const float3 edgeAC = renderablePositions[c].xyz - posA;
    const float2 uvAB = uvB - uvA;
    const float2 uvAC = uvC - uvA;

    // The cross product is left unnormalized on purpose of the accumulation:
    // its length is carried into the per-particle sum as-is.
    const float4 faceNormal = float4(cross(edgeAB, edgeAC), 0);

    // 2D cross product of the UV edges; near zero means degenerate UVs.
    const float uvArea = uvAB.x * uvAC.y - uvAC.x * uvAB.y;

    // With degenerate UVs the tangent contribution stays at zero.
    float4 faceTangent = FLOAT4_ZERO;
    if (abs(uvArea) > EPSILON)
        faceTangent = float4(uvAC.y * edgeAB - uvAB.y * edgeAC, 0) / uvArea;

    AccumulateNormal(a, faceNormal);
    AccumulateNormal(b, faceNormal);
    AccumulateNormal(c, faceNormal);

    AccumulateTangent(a, faceTangent);
    AccumulateTangent(b, faceTangent);
    AccumulateTangent(c, faceTangent);
}

// Kernel: one thread per edge (threads at or past edgeCount exit early).
// Takes the edge's average velocity minus its average wind, removes the part of
// that vector lying along the edge, and accumulates the remainder as a normal
// contribution on both edge particles.
[numthreads(128, 1, 1)]
void UpdateEdgeNormals (uint3 id : SV_DispatchThreadID)
{
    const uint e = id.x;
    if (e >= edgeCount) return;

    // Edges are stored as two consecutive particle indices.
    const int a = deformableEdges[e * 2];
    const int b = deformableEdges[e * 2 + 1];

    const float4 edgeVec = renderablePositions[b] - renderablePositions[a];

    const float4 meanVelocity = (velocities[a] + velocities[b]) * 0.5f;
    const float4 meanWind     = (wind[a] + wind[b]) * 0.5f;
    const float4 relative     = meanVelocity - meanWind;

    // NOTE(review): both dot products below run over all four components,
    // so they include w - confirm the w components are expected to be zero.
    const float lengthSq = dot(edgeVec, edgeVec);

    // Component of `relative` along the edge; skipped for (near) zero-length
    // edges to avoid dividing by a tiny number.
    float4 alongEdge;
    if (lengthSq < EPSILON)
        alongEdge = FLOAT4_ZERO;
    else
        alongEdge = edgeVec * dot(relative, edgeVec) / lengthSq;

    const float4 edgeNormal = relative - alongEdge;

    AccumulateNormal(a, edgeNormal);
    AccumulateNormal(b, edgeNormal);
}

// Kernel: one thread per particle (threads at or past normalsCount exit early).
// For eligible particles, normalizes the accumulated normal and tangent in
// place and writes an orientation built from the two via q_look_at.
[numthreads(128, 1, 1)]
void OrientationFromNormals (uint3 id : SV_DispatchThreadID)
{
    const uint particle = id.x;
    if (particle >= normalsCount) return;

    // Skip fluid particles.
    if ((phases[particle] & (int)PHASE_FLUID) != 0) return;

    const float4 accumNormal = normals[particle];

    // Skip particles with a negative normal w (softbodies, which store an SDF
    // there according to the original note).
    if (!(accumNormal.w >= 0)) return;

    const float4 accumTangent = asfloat(tangentsInt[particle]);

    // Both vectors must be long enough to normalize; otherwise the stored
    // normal, tangent and orientation are all left as they were.
    const bool usable = dot(accumNormal, accumNormal) > EPSILON &&
                        dot(accumTangent, accumTangent) > EPSILON;
    if (usable)
    {
        const float4 unitNormal  = normalizesafe(accumNormal);
        const float4 unitTangent = normalizesafe(accumTangent);

        normals[particle]     = unitNormal;
        tangentsInt[particle] = asuint(unitTangent);

        // particle orientation from normal/tangent:
        renderableOrientations[particle] = q_look_at(unitNormal.xyz, unitTangent.xyz);
    }
}
